import os
import json
import sys
import re
import random
import numpy as np
from openai import AzureOpenAI
from tqdm import tqdm
import argparse

# Command-line interface. Options are registered in the same order they are
# listed by --help.
parser = argparse.ArgumentParser()

# NOTE(review): --count_personas is not read anywhere in the visible script.
parser.add_argument(
    "--count_personas",
    action="store_true",
    help="Output the number of personas",
)

# Slice of the sampled records handled by this invocation.
parser.add_argument(
    "--start",
    type=int,
    default=0,
    help="Start index for dataset slicing",
)
parser.add_argument(
    "--end",
    type=int,
    default=None,
    help="End index for dataset slicing (inclusive)",
)
parser.add_argument(
    "--output_dir",
    type=str,
    default="results",
    help="Directory to save per-job JSON outputs",
)

# Repetition counts, given as a comma-separated list.
parser.add_argument(
    "--runs_list",
    type=str,
    default="20,40,60,80,100",
    help='Comma-separated list indicating how many times to repeat the prediction for each datapoint (e.g., "20,40,60")',
)

# GMO evaluation options.
parser.add_argument(
    "--gmo",
    action="store_true",
    help="Run GMO CTR/CPA evaluation mode (ads) instead of WebAES",
)
parser.add_argument(
    "--dataset_dir",
    type=str,
    default="/path/to/ctr_dataset/",
    help="Directory containing *.jsonl files for GMO evaluation",
)
parser.add_argument(
    "--limit",
    type=int,
    default=None,
    help="Total number of samples to evaluate across all dataset files (evenly sampled)",
)

# Seed used for deterministic sampling.
parser.add_argument(
    "--seed",
    type=int,
    default=42,
    help="Random seed to ensure consistent sampling across runs",
)

args = parser.parse_args()

# ---------------------------------------------------------------------------
# Seed the stdlib and NumPy generators from --seed so that record sampling is
# repeatable across run counts and across separate invocations of the script
# that pass the same seed.
# ---------------------------------------------------------------------------
random.seed(args.seed)
np.random.seed(args.seed)

# Turn --runs_list into a list of positive ints. Blank entries (e.g. from a
# trailing comma) are ignored and non-positive counts are dropped.
runs_list = [
    count
    for count in (int(token) for token in args.runs_list.split(',') if token.strip())
    if count > 0
]
if len(runs_list) == 0:
    raise ValueError("--runs_list must contain at least one positive integer")

# Resolve the output directory to an absolute path, create it, and store the
# absolute form back on args so later code sees a consistent location.
output_dir = os.path.abspath(args.output_dir)
os.makedirs(output_dir, exist_ok=True)
args.output_dir = output_dir


# ---------------------------------------------------------------------------
# NOTE: Legacy WebAES (website likeability) code and prompts have been removed.
# This script now focuses solely on GMO CPA-percentile evaluation.
# ---------------------------------------------------------------------------

# (Image handling utilities removed – not required for CPA evaluation)

# Azure OpenAI Configuration
# The key and endpoint are read from the environment when set, so real
# credentials never have to be written into this file. The literal
# placeholders are kept as fallbacks, so a setup that edits them in place
# keeps working as long as the variables are unset.
api_version = "2024-02-15-preview"
config_dict = {
    'api_key': os.environ.get("AZURE_OPENAI_API_KEY", "YOUR_OPENAI_API_KEY"),
    'api_version': api_version,
    'azure_endpoint': os.environ.get(
        "AZURE_OPENAI_ENDPOINT", "https://your-azure-openai-endpoint/"
    ),
}

# ----------------------------- GMO MODE HELPERS -----------------------------
# System prompt sent with every GMO request. It asks for a two-line reply
# whose first line ("Answer: <number>") carries the predicted percentile.
SYSTEM_PROMPT_GMO = "\n".join(
    [
        "You are an expert digital advertisement analyst. "
        "Given the ad description below, predict the CPA percentile it will achieve (0-100).",
        "",
        "Return your response in exactly two lines:",
        "Answer: <0-100>",
        "Reason: <brief justification>",
    ]
)

def _sample_gmo_records(dataset_dir: str, total_limit: int | None):
    """Randomly sample records from each *.jsonl file in `dataset_dir`.

    If `total_limit` is provided, samples are taken as `total_limit // n_files` per file.
    Otherwise, all records from each file are returned.
    """
    file_paths = [os.path.join(dataset_dir, fp) for fp in os.listdir(dataset_dir) if fp.endswith('.jsonl')]
    if not file_paths:
        print(f"[ERROR] No .jsonl files found in {dataset_dir}", file=sys.stderr)
        sys.exit(1)

    random.shuffle(file_paths)  # shuffle to avoid ordering bias
    records = []
    per_file = None
    if total_limit is not None and total_limit > 0:
        per_file = max(1, total_limit // len(file_paths))

    for fp in file_paths:
        with open(fp, 'r', encoding='utf-8') as f_in:
            lines = f_in.readlines()
        if per_file is not None:
            chosen = random.sample(lines, min(per_file, len(lines)))
        else:
            chosen = lines
        for ln in chosen:
            try:
                rec = json.loads(ln)
                rec['_source_file'] = os.path.basename(fp)
                records.append(rec)
            except json.JSONDecodeError:
                continue  # skip malformed lines

    # If we overshot the limit due to rounding, trim back down
    if total_limit is not None and len(records) > total_limit:
        records = random.sample(records, total_limit)
    return records

# Client shared by every _verbalize_gmo call; created on first use so that
# importing this module does not require valid credentials.
_GMO_CLIENT = None


def _verbalize_gmo(prompt: str) -> str:
    """Call Azure OpenAI with the GMO system prompt.

    Args:
        prompt: Ad description sent as the user message.

    Returns:
        The text of the first completion choice with surrounding whitespace
        removed, or an empty string when the service returns no text content
        for that choice.
    """
    global _GMO_CLIENT
    if _GMO_CLIENT is None:
        # Build the client once instead of once per request; the settings in
        # config_dict do not change during a run.
        _GMO_CLIENT = AzureOpenAI(
            api_key=config_dict['api_key'],
            api_version=config_dict['api_version'],
            azure_endpoint=config_dict['azure_endpoint'],
        )
    messages = [
        {"role": "system", "content": SYSTEM_PROMPT_GMO},
        # NOTE(review): "/no_think" is joined to the prompt with no separator;
        # left as-is so model inputs are unchanged — confirm it is intended.
        {"role": "user", "content": "/no_think" + prompt},
    ]
    resp = _GMO_CLIENT.chat.completions.create(
        model='gpt-4o',
        messages=messages,
        max_tokens=350,
        temperature=0.85,
        n=1,
    )
    # `content` is nullable in the API response; treat a missing value as an
    # empty reply rather than crashing on `.strip()`.
    content = resp.choices[0].message.content
    return (content or "").strip()

# Finds the number following "answer" (case-insensitive) in a model reply,
# allowing up to 10 non-digit characters in between. Compiled once because it
# is applied to every response.
_GMO_ANSWER_RE = re.compile(r"(?i)answer[^0-9]{0,10}(\d{1,3}(?:\.\d+)?)")


def _parse_gmo_score(resp_text: str) -> float | None:
    """Return the predicted score in `resp_text` clamped to [0, 100], or None if absent."""
    match = _GMO_ANSWER_RE.search(resp_text)
    if match is None:
        return None
    return max(0.0, min(100.0, float(match.group(1))))


def _write_json_atomic(path: str, payload) -> None:
    """Write `payload` as indented JSON to `path` via a temp file and rename."""
    tmp_path = f"{path}.tmp"
    with open(tmp_path, "w", encoding="utf-8") as f_out:
        json.dump(payload, f_out, indent=2)
    # os.replace swaps the file in one step, so an interrupted save leaves the
    # previous complete checkpoint in place instead of a truncated file.
    os.replace(tmp_path, path)


def run_gmo_evaluation(args):
    """Main entry for GMO evaluation mode with support for multiple repetition counts (runs_list).

    Samples records from `args.dataset_dir` (bounded by `args.limit`), keeps
    the slice selected by `args.start` / `args.end`, and for every count in
    the module-level `runs_list` queries the model that many times per record.

    For each repetition count one JSON file is written to `args.output_dir`,
    named ``gmo_results_runs{n}_samples{limit|all}_{start}_{end}.json``. It
    holds a list with one entry per record: ``prompt``, ``ground_truth``,
    ``predictions`` (parsed scores only), ``mean_prediction`` (None when no
    reply could be parsed), ``responses`` (every raw reply) and
    ``source_file``.

    Args:
        args: Parsed command-line namespace; reads `dataset_dir`, `limit`,
            `start`, `end` and `output_dir`.
    """

    # ---------------------------- Record Sampling ----------------------------
    records = _sample_gmo_records(args.dataset_dir, args.limit)

    # Apply optional slicing using --start and --end (0-based, inclusive)
    slice_start = max(0, args.start)
    slice_end = args.end if args.end is not None else len(records) - 1
    slice_end = min(slice_end, len(records) - 1)
    records = records[slice_start : slice_end + 1]

    print(
        f"[INFO] Running GMO evaluation on {len(records)} sampled records (slice {slice_start}-{slice_end})."
    )

    # Ensure output directory exists
    os.makedirs(args.output_dir, exist_ok=True)

    # ----------------------- Iterate over repetition counts ------------------
    for n_runs in runs_list:
        print("\n" + "=" * 80)
        print(f"Running evaluation with {n_runs} repetitions per datapoint…")
        print("=" * 80)

        run_results = []

        # Construct output filename that includes run count so downstream merge works
        out_name = (
            f"gmo_results_runs{n_runs}_samples{args.limit or 'all'}_{slice_start}_{slice_end}.json"
        )
        out_path = os.path.join(args.output_dir, out_name)

        for rec in tqdm(records, desc=f"GMO Samples x{n_runs}"):
            ad_prompt = rec.get("prompt", "")
            ground_truth = rec.get("response")

            predictions: list[float] = []
            reasons: list[str] = []

            for _ in range(n_runs):
                resp_text = _verbalize_gmo(ad_prompt)

                # Unparseable replies are kept in `reasons` for inspection but
                # contribute nothing to the mean.
                score = _parse_gmo_score(resp_text)
                if score is not None:
                    predictions.append(score)
                reasons.append(resp_text)

            mean_prediction = float(np.mean(predictions)) if predictions else None

            run_results.append(
                {
                    "prompt": ad_prompt,
                    "ground_truth": ground_truth,
                    "predictions": predictions,
                    "mean_prediction": mean_prediction,
                    "responses": reasons,
                    "source_file": rec.get("_source_file"),
                }
            )

            # Incremental write after each datapoint to protect against crashes.
            # Best-effort: a failed checkpoint must not abort the evaluation.
            try:
                _write_json_atomic(out_path, run_results)
            except Exception as e:
                print(f"[WARNING] Incremental save failed: {e}")

        # Final write for this n_runs (also produces the file when there were
        # no records, in which case no incremental write happened).
        try:
            _write_json_atomic(out_path, run_results)
        except Exception as e:
            print(f"[ERROR] Final save failed for {out_path}: {e}")

        print(f"[INFO] GMO evaluation with {n_runs} runs complete. Results saved to {out_path}")

# ---------------------------------------------------------------------------
# Entry point: run GMO mode and exit. GMO is the only mode implemented in this
# script, so when --gmo is omitted say so on stderr instead of finishing
# silently; the exit status in that case is left unchanged.
# ---------------------------------------------------------------------------
if args.gmo:
    run_gmo_evaluation(args)
    sys.exit(0)
else:
    print(
        "[WARNING] No evaluation mode selected; nothing was run. "
        "Pass --gmo to run the GMO evaluation.",
        file=sys.stderr,
    )

# No additional code below this point. 